import numpy as np
from PIL import Image
import torch
from typing import Callable, List, Tuple, Optional
from sklearn.decomposition import NMF
from ..activations_and_gradients import ActivationsAndGradients
from ..utils.image import (
    scale_cam_image,
    create_labels_legend,
    show_factorization_on_image,
)


def dff(activations: np.ndarray, n_components: int = 5):
    """Compute Deep Feature Factorization on a batch of 2D activation maps.

    The activations are rearranged into a (channels, batch * height * width)
    matrix, shifted per channel so that every entry is non-negative, and
    factorized with Non-negative Matrix Factorization into ``W @ H``.

    :param activations: A numpy array of shape batch x channels x height x width.
                        The array is not modified; NaN entries are treated as 0.
    :param n_components: The number of components for the non negative matrix factorization
    :returns: A tuple of the concepts (a numpy array with shape channels x components),
              and the explanation heatmaps (a numpy array with shape
              batch x components x height x width)
    """

    batch_size, channels, h, w = activations.shape

    # Channels first, so that each row of the flattened matrix holds every
    # spatial position of every image for a single channel.
    channel_major = activations.transpose((1, 0, 2, 3))

    # transpose() returns a view of the caller's array. Build a NaN-free copy
    # instead of assigning through the view, which would silently overwrite
    # the caller's data.
    channel_major = np.where(np.isnan(channel_major), 0, channel_major)
    flattened = channel_major.reshape(channels, -1)

    # NMF needs non-negative input: subtract each channel's minimum, and add
    # it back onto the concept matrix afterwards.
    offset = flattened.min(axis=-1)
    flattened = flattened - offset[:, None]

    # Fixed seed so that repeated calls on the same activations agree.
    model = NMF(n_components=n_components, init="random", random_state=0)
    W = model.fit_transform(flattened)
    H = model.components_

    concepts = W + offset[:, None]
    # H is (components, batch * h * w); restore the spatial layout and put
    # the batch axis first.
    explanations = H.reshape(n_components, batch_size, h, w)
    explanations = explanations.transpose((1, 0, 2, 3))
    return concepts, explanations


class DeepFeatureFactorization:
    """Deep Feature Factorization: https://arxiv.org/abs/1806.10206
    This gets a model and computes the 2D activations for a target layer,
    and computes Non Negative Matrix Factorization on the activations.

    Optionally it runs a computation on the concept embeddings,
    like running a classifier on them.

    The explanation heatmaps are scaled to the range [0, 1]
    and to the input tensor width and height.

    Instances can be used as context managers; the hooks held by the
    internal ``ActivationsAndGradients`` object are released on exit.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        target_layer: torch.nn.Module,
        reshape_transform: Optional[Callable] = None,
        computation_on_concepts=None,
    ):
        """
        :param model: The model whose activations are factorized.
        :param target_layer: The layer whose activations are captured.
        :param reshape_transform: Optional callable forwarded to
                                  ``ActivationsAndGradients``.
        :param computation_on_concepts: Optional callable applied to the
                                        concept embeddings (a float32 tensor
                                        of shape components x channels).
        """
        self.model = model
        self.computation_on_concepts = computation_on_concepts
        self.activations_and_grads = ActivationsAndGradients(
            self.model, [target_layer], reshape_transform
        )

    def __call__(self, input_tensor: torch.Tensor, n_components: int = 16):
        """Factorize the target layer activations for ``input_tensor``.

        :param input_tensor: A tensor of shape batch x channels x height x width.
        :param n_components: The number of NMF components.
        :returns: ``(concepts, explanations)`` when no computation on the
                  concepts was configured, otherwise
                  ``(concepts, explanations, concept_outputs)``.
                  ``explanations`` is a list with one entry per image in the
                  batch, each produced by ``scale_cam_image`` with the input
                  width and height as target size.
        """
        # Only the spatial size is needed, as the target size of the heatmaps.
        _, _, h, w = input_tensor.size()

        # The forward pass is run for its side effect: the captured
        # activations are read back from ``activations_and_grads`` below.
        self.activations_and_grads(input_tensor)

        with torch.no_grad():
            activations = self.activations_and_grads.activations[0].cpu().numpy()

        concepts, explanations = dff(activations, n_components=n_components)

        processed_explanations = [
            scale_cam_image(image_explanations, (w, h))
            for image_explanations in explanations
        ]

        # Compare against None explicitly: a callable that defines __len__
        # (e.g. an empty container module) may be falsy and would otherwise
        # be skipped silently.
        if self.computation_on_concepts is None:
            return concepts, processed_explanations

        with torch.no_grad():
            # concepts is channels x components; the computation receives one
            # row per concept.
            concept_tensors = torch.from_numpy(
                np.float32(concepts).transpose((1, 0))
            )
            concept_outputs = (
                self.computation_on_concepts(concept_tensors).cpu().numpy()
            )
        return concepts, processed_explanations, concept_outputs

    def __del__(self):
        # __del__ also runs when __init__ raised before the attribute was
        # assigned, so it must not assume the attribute exists.
        activations_and_grads = getattr(self, "activations_and_grads", None)
        if activations_and_grads is not None:
            activations_and_grads.release()

    def __enter__(self):
        # Required counterpart of __exit__ so the object works in ``with``.
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Release the hooks; report and suppress ``IndexError`` only."""
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            # Handle IndexError here...
            print(
                f"An exception occurred in DeepFeatureFactorization with block: {exc_type}. Message: {exc_value}"
            )
            return True
        # Any other exception (or none at all) propagates normally.
        return False


def run_dff_on_image(
    model: torch.nn.Module,
    target_layer: torch.nn.Module,
    classifier: torch.nn.Module,
    img_pil: "Image.Image",
    img_tensor: torch.Tensor,
    reshape_transform: Optional[Callable] = None,
    n_components: int = 5,
    top_k: int = 2,
) -> np.ndarray:
    """Helper function to create a Deep Feature Factorization visualization for a single image.
    TBD: Run this on a batch with several images.

    :param model: The model to explain. It must expose ``config.id2label``,
                  which is used to label the concepts.
    :param target_layer: The layer whose activations are factorized.
    :param classifier: Callable applied to the concept embeddings; its output
                       is passed through a softmax over the last axis.
    :param img_pil: The image to draw on.
    :param img_tensor: The preprocessed image without a batch dimension;
                       a batch axis is added before it is given to the model.
    :param reshape_transform: Optional callable forwarded to
                              ``DeepFeatureFactorization``.
    :param n_components: The number of NMF components.
    :param top_k: Forwarded to ``create_labels_legend``.
    :returns: The original image and the visualization stacked horizontally.
    """
    rgb_img_float = np.array(img_pil) / 255

    # Named ``factorizer`` so the module-level ``dff`` function is not shadowed.
    factorizer = DeepFeatureFactorization(
        model=model,
        reshape_transform=reshape_transform,
        target_layer=target_layer,
        computation_on_concepts=classifier,
    )

    # ``img_tensor[None, :]`` adds the leading batch axis. ``concepts`` is
    # returned as well but not needed for the visualization.
    _, batch_explanations, concept_outputs = factorizer(
        img_tensor[None, :], n_components
    )

    concept_outputs = torch.softmax(
        torch.from_numpy(concept_outputs), dim=-1
    ).numpy()
    concept_label_strings = create_labels_legend(
        concept_outputs, labels=model.config.id2label, top_k=top_k
    )

    # Single image: only the first (and only) entry of the batch is drawn.
    visualization = show_factorization_on_image(
        rgb_img_float,
        batch_explanations[0],
        image_weight=0.3,
        concept_labels=concept_label_strings,
    )

    return np.hstack((np.array(img_pil), visualization))
